{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#skip\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.data.all import *\n",
    "from fastai.optimizer import *\n",
    "from fastai.learner import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#default_exp metrics\n",
    "# default_cls_lvl 3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Metrics\n",
    "\n",
    "> Definition of the metrics that can be used in training models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Core metric"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is where the function that converts scikit-learn metrics to fastai metrics is defined. You should skip this section unless you want to know all about the internals of fastai."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "import sklearn.metrics as skm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "import scipy.stats as scs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export torch_core\n",
    "def flatten_check(inp, targ):\n",
    "    \"Flatten `inp` and `targ` to 1d and check that they have the same number of elements.\"\n",
    "    flat_inp  = inp.contiguous().view(-1)\n",
    "    flat_targ = targ.contiguous().view(-1)\n",
    "    test_eq(len(flat_inp), len(flat_targ))\n",
    "    return flat_inp,flat_targ"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x1,x2 = torch.randn(5,4),torch.randn(20)\n",
    "x1,x2 = flatten_check(x1,x2)\n",
    "test_eq(x1.shape, [20])\n",
    "test_eq(x2.shape, [20])\n",
    "x1,x2 = torch.randn(5,4),torch.randn(21)\n",
    "test_fail(lambda: flatten_check(x1,x2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "# Namespace of the activations `AccumMetric` can apply; each attribute holds its own lower-cased name\n",
    "mk_class('ActivationType', **{o:o.lower() for o in ['No', 'Sigmoid', 'Softmax', 'BinarySoftmax']},\n",
    "         doc=\"All possible activation classes for `AccumMetric`\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class AccumMetric(Metric):\n",
    "    \"Stores predictions and targets on CPU in accumulate to perform final calculations with `func`.\"\n",
    "    def __init__(self, func, dim_argmax=None, activation=ActivationType.No, thresh=None, to_np=False,\n",
    "                 invert_arg=False, flatten=True, **kwargs):\n",
    "        # `func` is called in `value` as `func(preds, targs, **kwargs)` (arguments swapped if `invert_arg`);\n",
    "        # `dim_argmax`, `activation` and `thresh` control how `learn.pred` is processed in `accumulate`;\n",
    "        # `to_np` converts the accumulated tensors to numpy arrays before calling `func`;\n",
    "        # `flatten` makes `accum_values` flatten each batch of preds/targs with `flatten_check`.\n",
    "        store_attr('func,dim_argmax,activation,thresh,flatten')\n",
    "        self.to_np,self.invert_args,self.kwargs = to_np,invert_arg,kwargs\n",
    "\n",
    "    def reset(self):\n",
    "        \"Clear all targs and preds\"\n",
    "        self.targs,self.preds = [],[]\n",
    "\n",
    "    def accumulate(self, learn):\n",
    "        \"Store targs and preds from `learn`, using activation function and argmax as appropriate\"\n",
    "        pred = learn.pred\n",
    "        if self.activation in [ActivationType.Softmax, ActivationType.BinarySoftmax]:\n",
    "            pred = F.softmax(pred, dim=self.dim_argmax)\n",
    "            # Binary case: only keep the probability of the last class\n",
    "            if self.activation == ActivationType.BinarySoftmax: pred = pred[:, -1]\n",
    "        elif self.activation == ActivationType.Sigmoid: pred = torch.sigmoid(pred)\n",
    "        # Test against `None` (not truthiness) so that `dim_argmax=0` is honored\n",
    "        elif self.dim_argmax is not None: pred = pred.argmax(dim=self.dim_argmax)\n",
    "        # Same for the threshold: 0 is a valid value (e.g. when thresholding raw outputs)\n",
    "        if self.thresh is not None: pred = (pred >= self.thresh)\n",
    "        self.accum_values(pred,learn.y,learn)\n",
    "\n",
    "    def accum_values(self, preds, targs,learn=None):\n",
    "        \"Store targs and preds\"\n",
    "        # Use the learner's own `to_detach` when one is given, the module-level function otherwise\n",
    "        to_d = learn.to_detach if learn is not None else to_detach\n",
    "        preds,targs = to_d(preds),to_d(targs)\n",
    "        if self.flatten: preds,targs = flatten_check(preds,targs)\n",
    "        self.preds.append(preds)\n",
    "        self.targs.append(targs)\n",
    "\n",
    "    def __call__(self, preds, targs):\n",
    "        \"Calculate metric on one batch of data\"\n",
    "        # Anything accumulated so far is discarded: only this batch is taken into account\n",
    "        self.reset()\n",
    "        self.accum_values(preds,targs)\n",
    "        return self.value\n",
    "\n",
    "    @property\n",
    "    def value(self):\n",
    "        \"Value of the metric using accumulated preds and targs\"\n",
    "        # Nothing accumulated yet: there is no value to report\n",
    "        if len(self.preds) == 0: return\n",
    "        preds,targs = torch.cat(self.preds),torch.cat(self.targs)\n",
    "        if self.to_np: preds,targs = preds.numpy(),targs.numpy()\n",
    "        # With `invert_arg=True`, `func` receives `(targs, preds)` instead of `(preds, targs)`\n",
    "        return self.func(targs, preds, **self.kwargs) if self.invert_args else self.func(preds, targs, **self.kwargs)\n",
    "\n",
    "    @property\n",
    "    def name(self):\n",
    "        \"Name of `func` (or of `func.func` when `func` has a `func` attribute)\"\n",
    "        return self.func.func.__name__ if hasattr(self.func, 'func') else  self.func.__name__"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`func` is only applied to the accumulated predictions/targets when the `value` attribute is asked for (so at the end of a validation/training phase, in use with `Learner` and its `Recorder`). The signature of `func` should be `inp,targ` (where `inp` are the predictions of the model and `targ` the corresponding labels).\n",
    "\n",
    "For single-label classification problems, predictions need to be transformed with a softmax then an argmax before being compared to the targets. Since a softmax doesn't change the order of the numbers, we can just apply the argmax. Pass along `dim_argmax` to have this done by `AccumMetric` (usually -1 will work pretty well). If you need to pass to your metrics the probabilities and not the predictions, use `activation=ActivationType.Softmax`.\n",
    "\n",
    "For multi-label classification problems, or if your targets are one-hot encoded, predictions may need to pass through a sigmoid (if it wasn't included in your model) and then be compared to a given threshold (to decide between 0 and 1). This is done by `AccumMetric` if you pass `activation=ActivationType.Sigmoid` and/or a value for `thresh`.\n",
    "\n",
    "If you want to use a metric function from `sklearn.metrics`, you will need to convert predictions and labels to numpy arrays with `to_np=True`. Also, scikit-learn metrics adopt the convention `y_true`, `y_pred`, which is the opposite of ours, so you will need to pass `invert_arg=True` to make `AccumMetric` do the inversion for you."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#For testing: a fake learner and a metric that isn't an average\n",
    "@delegates()\n",
    "class TstLearner(Learner):\n",
    "    \"Fake `Learner` for tests: its `__init__` only sets `pred`, `xb` and `yb` to `None`\"\n",
    "    def __init__(self,dls=None,model=None,**kwargs):\n",
    "        self.pred = None\n",
    "        self.xb = None\n",
    "        self.yb = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _l2_mean(x,y):\n",
    "    \"Square root of the mean squared difference between `x` and `y` (both cast to float)\"\n",
    "    sq_diff = (x.float()-y.float()).pow(2)\n",
    "    return torch.sqrt(sq_diff.mean())\n",
    "\n",
    "#Go through a fake cycle with various batch sizes and computes the value of met\n",
    "def compute_val(met, x1, x2):\n",
    "    \"Accumulate `x1` (preds) and `x2` (targets) in `met` over three batches of different sizes, then return `met.value`\"\n",
    "    met.reset()\n",
    "    splits = [0,6,15,20]\n",
    "    learn = TstLearner()\n",
    "    for start,stop in zip(splits[:-1], splits[1:]):\n",
    "        learn.pred,learn.yb = x1[start:stop],(x2[start:stop],)\n",
    "        met.accumulate(learn)\n",
    "    return met.value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x1,x2 = torch.randn(20,5),torch.randn(20,5)\n",
    "tst = AccumMetric(_l2_mean)\n",
    "test_close(compute_val(tst, x1, x2), _l2_mean(x1, x2))\n",
    "test_eq(torch.cat(tst.preds), x1.view(-1))\n",
    "test_eq(torch.cat(tst.targs), x2.view(-1))\n",
    "\n",
    "#test argmax\n",
    "x1,x2 = torch.randn(20,5),torch.randint(0, 5, (20,))\n",
    "tst = AccumMetric(_l2_mean, dim_argmax=-1)\n",
    "test_close(compute_val(tst, x1, x2), _l2_mean(x1.argmax(dim=-1), x2))\n",
    "\n",
    "#test thresh\n",
    "x1,x2 = torch.randn(20,5),torch.randint(0, 2, (20,5)).bool()\n",
    "tst = AccumMetric(_l2_mean, thresh=0.5)\n",
    "test_close(compute_val(tst, x1, x2), _l2_mean((x1 >= 0.5), x2))\n",
    "\n",
    "#test sigmoid\n",
    "x1,x2 = torch.randn(20,5),torch.randn(20,5)\n",
    "tst = AccumMetric(_l2_mean, activation=ActivationType.Sigmoid)\n",
    "test_close(compute_val(tst, x1, x2), _l2_mean(torch.sigmoid(x1), x2))\n",
    "\n",
    "#test to_np\n",
    "x1,x2 = torch.randn(20,5),torch.randn(20,5)\n",
    "tst = AccumMetric(lambda x,y: isinstance(x, np.ndarray) and isinstance(y, np.ndarray), to_np=True)\n",
    "assert compute_val(tst, x1, x2)\n",
    "\n",
    "#test invert_arg\n",
    "x1,x2 = torch.randn(20,5),torch.randn(20,5)\n",
    "tst = AccumMetric(lambda x,y: torch.sqrt(x.pow(2).mean()))\n",
    "test_close(compute_val(tst, x1, x2), torch.sqrt(x1.pow(2).mean()))\n",
    "tst = AccumMetric(lambda x,y: torch.sqrt(x.pow(2).mean()), invert_arg=True)\n",
    "test_close(compute_val(tst, x1, x2), torch.sqrt(x2.pow(2).mean()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "def _l2_mean(x,y): return torch.sqrt((x.argmax(dim=-1).float()-y.float()).pow(2).mean())\n",
    "x1,x2 = torch.randn(20,5),torch.randint(0, 5, (20,))\n",
    "tst = AccumMetric(_l2_mean, dim_argmax=-1, flatten=False, activation=ActivationType.Softmax)\n",
    "test_close(compute_val(tst, x1, x2), _l2_mean(F.softmax(x1, dim=-1), x2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def skm_to_fastai(func, is_class=True, thresh=None, axis=-1, activation=None, **kwargs):\n",
    "    \"Wrap `func` from sklearn.metrics in an `AccumMetric` so it can be used as a fastai metric\"\n",
    "    has_thresh = thresh is not None\n",
    "    # Classification without a threshold: predictions go through an argmax over `axis`\n",
    "    dim_argmax = axis if (is_class and not has_thresh) else None\n",
    "    # No explicit activation: use a sigmoid for classification with a threshold, nothing otherwise\n",
    "    if activation is None:\n",
    "        if is_class and has_thresh: activation = ActivationType.Sigmoid\n",
    "        else:                       activation = ActivationType.No\n",
    "    # sklearn works on numpy arrays and expects targets first, hence `to_np` and `invert_arg`\n",
    "    return AccumMetric(func, dim_argmax=dim_argmax, activation=activation, thresh=thresh,\n",
    "                       to_np=True, invert_arg=True, **kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is the quickest way to use a scikit-learn metric in a fastai training loop. `is_class` indicates if you are in a classification problem or not. In this case:\n",
    "- leaving `thresh` to `None` indicates it's a single-label classification problem and predictions will pass through an argmax over `axis` before being compared to the targets\n",
    "- setting a value for `thresh` indicates it's a multi-label classification problem and predictions will pass through a sigmoid (can be deactivated with `activation=ActivationType.No`) and be compared to `thresh` before being compared to the targets\n",
    "\n",
    "If `is_class=False`, it indicates you are in a regression problem, and predictions are compared to the targets without being modified. In all cases, `kwargs` are extra keyword arguments passed to `func`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst_single = skm_to_fastai(skm.precision_score)\n",
    "x1,x2 = torch.randn(20,2),torch.randint(0, 2, (20,))\n",
    "test_close(compute_val(tst_single, x1, x2), skm.precision_score(x2, x1.argmax(dim=-1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst_multi = skm_to_fastai(skm.precision_score, thresh=0.2)\n",
    "x1,x2 = torch.randn(20),torch.randint(0, 2, (20,))\n",
    "test_close(compute_val(tst_multi, x1, x2), skm.precision_score(x2, torch.sigmoid(x1) >= 0.2))\n",
    "\n",
    "tst_multi = skm_to_fastai(skm.precision_score, thresh=0.2, activation=ActivationType.No)\n",
    "x1,x2 = torch.randn(20),torch.randint(0, 2, (20,))\n",
    "test_close(compute_val(tst_multi, x1, x2), skm.precision_score(x2, x1 >= 0.2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst_reg = skm_to_fastai(skm.r2_score, is_class=False)\n",
    "x1,x2 = torch.randn(20,5),torch.randn(20,5)\n",
    "test_close(compute_val(tst_reg, x1, x2), skm.r2_score(x2.view(-1), x1.view(-1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_close(tst_reg(x1, x2), skm.r2_score(x2.view(-1), x1.view(-1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def optim_metric(f, argname, bounds, tol=0.01, do_neg=True, get_x=False):\n",
    "    \"Replace metric `f` with a version that optimizes argument `argname`\"\n",
    "    def _f(preds, targs):\n",
    "        def minfunc(x):\n",
    "            kwargs = {argname:x}\n",
    "            res = f(preds, targs, **kwargs)\n",
    "            return -res if do_neg else res\n",
    "        optres = scipy.optimize.minimize_scalar(minfunc, bounds=bounds, method='bounded',\n",
    "                                                options={'xatol':0.01})\n",
    "        fun = -optres.fun if do_neg else optres.fun\n",
    "        return (fun,optres.x) if get_x else fun\n",
    "    _f.__name__ = f'opt_{f.__name__}'\n",
    "    return _f"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Single-label classification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> Warning: All functions defined in this section are intended for single-label classification and targets that are not one-hot encoded. For multi-label problems or one-hot encoded targets, use the version suffixed with multi."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> Warning: Many metrics in fastai are thin wrappers around sklearn functionality. However, sklearn metrics can handle Python lists of strings, amongst other things, whereas fastai metrics work with PyTorch, and thus require tensors. The arguments that are passed to metrics are after all transformations, such as categories being converted to indices, have occurred. This means that when you pass a label to a metric, for instance, you must pass indices, not strings. Labels can be converted with `vocab.map_obj`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def accuracy(inp, targ, axis=-1):\n",
    "    \"Compute accuracy with `targ` when `inp` is bs * n_classes\"\n",
    "    pred = inp.argmax(dim=axis)\n",
    "    pred,targ = flatten_check(pred, targ)\n",
    "    is_right = (pred == targ).float()\n",
    "    return is_right.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#For testing\n",
    "def change_targ(targ, n, c):\n",
    "    idx = torch.randperm(len(targ))[:n]\n",
    "    res = targ.clone()\n",
    "    for i in idx: res[i] = (res[i]+random.randint(1,c-1))%c\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.randn(4,5)\n",
    "y = x.argmax(dim=1)\n",
    "test_eq(accuracy(x,y), 1)\n",
    "y1 = change_targ(y, 2, 5)\n",
    "test_eq(accuracy(x,y1), 0.5)\n",
    "test_eq(accuracy(x.unsqueeze(1).expand(4,2,5), torch.stack([y,y1], dim=1)), 0.75)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def error_rate(inp, targ, axis=-1):\n",
    "    \"Complement of `accuracy`, i.e. 1 - `accuracy`\"\n",
    "    acc = accuracy(inp, targ, axis=axis)\n",
    "    return 1 - acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.randn(4,5)\n",
    "y = x.argmax(dim=1)\n",
    "test_eq(error_rate(x,y), 0)\n",
    "y1 = change_targ(y, 2, 5)\n",
    "test_eq(error_rate(x,y1), 0.5)\n",
    "test_eq(error_rate(x.unsqueeze(1).expand(4,2,5), torch.stack([y,y1], dim=1)), 0.25)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def top_k_accuracy(inp, targ, k=5, axis=-1):\n",
    "    \"Computes the Top-k accuracy (`targ` is in the top `k` predictions of `inp`)\"\n",
    "    inp = inp.topk(k=k, dim=axis)[1]\n",
    "    targ = targ.unsqueeze(dim=axis).expand_as(inp)\n",
    "    return (inp == targ).sum(dim=-1).float().mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.randn(6,5)\n",
    "y = torch.arange(0,6)\n",
    "test_eq(top_k_accuracy(x[:5],y[:5]), 1)\n",
    "test_eq(top_k_accuracy(x, y), 5/6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def APScoreBinary(axis=-1, average='macro', pos_label=1, sample_weight=None):\n",
    "    \"Average Precision score for single-label binary classification problems\"\n",
    "    skm_kwargs = dict(average=average, pos_label=pos_label, sample_weight=sample_weight)\n",
    "    return skm_to_fastai(skm.average_precision_score, axis=axis,\n",
    "                         activation=ActivationType.BinarySoftmax, **skm_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html#sklearn.metrics.average_precision_score) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def BalancedAccuracy(axis=-1, sample_weight=None, adjusted=False):\n",
    "    \"Balanced accuracy score for single-label binary classification problems\"\n",
    "    skm_kwargs = dict(sample_weight=sample_weight, adjusted=adjusted)\n",
    "    return skm_to_fastai(skm.balanced_accuracy_score, axis=axis, **skm_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.balanced_accuracy_score.html#sklearn.metrics.balanced_accuracy_score) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def BrierScore(axis=-1, sample_weight=None, pos_label=None):\n",
    "    \"Brier score loss for single-label classification problems\"\n",
    "    skm_kwargs = dict(sample_weight=sample_weight, pos_label=pos_label)\n",
    "    return skm_to_fastai(skm.brier_score_loss, axis=axis, **skm_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.brier_score_loss.html#sklearn.metrics.brier_score_loss) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def CohenKappa(axis=-1, labels=None, weights=None, sample_weight=None):\n",
    "    \"Cohen's kappa score for single-label classification problems\"\n",
    "    skm_kwargs = dict(labels=labels, weights=weights, sample_weight=sample_weight)\n",
    "    return skm_to_fastai(skm.cohen_kappa_score, axis=axis, **skm_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.cohen_kappa_score.html#sklearn.metrics.cohen_kappa_score) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def F1Score(axis=-1, labels=None, pos_label=1, average='binary', sample_weight=None):\n",
    "    \"F1 score (`skm.f1_score`) for single-label classification problems\"\n",
    "    skm_kwargs = dict(labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)\n",
    "    return skm_to_fastai(skm.f1_score, axis=axis, **skm_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html#sklearn.metrics.f1_score) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def FBeta(beta, axis=-1, labels=None, pos_label=1, average='binary', sample_weight=None):\n",
    "    \"FBeta score (`skm.fbeta_score`) with `beta` for single-label classification problems\"\n",
    "    skm_kwargs = dict(beta=beta, labels=labels, pos_label=pos_label, average=average,\n",
    "                      sample_weight=sample_weight)\n",
    "    return skm_to_fastai(skm.fbeta_score, axis=axis, **skm_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.fbeta_score.html#sklearn.metrics.fbeta_score) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def HammingLoss(axis=-1, sample_weight=None):\n",
    "    \"Hamming loss (`skm.hamming_loss`) for single-label classification problems\"\n",
    "    skm_kwargs = dict(sample_weight=sample_weight)\n",
    "    return skm_to_fastai(skm.hamming_loss, axis=axis, **skm_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.hamming_loss.html#sklearn.metrics.hamming_loss) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def Jaccard(axis=-1, labels=None, pos_label=1, average='binary', sample_weight=None):\n",
    "    \"Jaccard score (`skm.jaccard_score`) for single-label classification problems\"\n",
    "    skm_kwargs = dict(labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)\n",
    "    return skm_to_fastai(skm.jaccard_score, axis=axis, **skm_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.jaccard_score.html#sklearn.metrics.jaccard_score) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def Precision(axis=-1, labels=None, pos_label=1, average='binary', sample_weight=None):\n",
    "    \"Precision (`skm.precision_score`) for single-label classification problems\"\n",
    "    skm_kwargs = dict(labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)\n",
    "    return skm_to_fastai(skm.precision_score, axis=axis, **skm_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html#sklearn.metrics.precision_score) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def Recall(axis=-1, labels=None, pos_label=1, average='binary', sample_weight=None):\n",
    "    \"Recall (`skm.recall_score`) for single-label classification problems\"\n",
    "    skm_kwargs = dict(labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)\n",
    "    return skm_to_fastai(skm.recall_score, axis=axis, **skm_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html#sklearn.metrics.recall_score) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def RocAuc(axis=-1, average='macro', sample_weight=None, max_fpr=None, multi_class='ovr'):\n",
    "    \"Area Under the ROC Curve for single-label multiclass classification problems\"\n",
    "    assert multi_class in ['ovr', 'ovo']\n",
    "    skm_kwargs = dict(average=average, sample_weight=sample_weight, max_fpr=max_fpr, multi_class=multi_class)\n",
    "    # Predictions go through a softmax over `axis` and are stored without flattening\n",
    "    return skm_to_fastai(skm.roc_auc_score, axis=axis, activation=ActivationType.Softmax,\n",
    "                         flatten=False, **skm_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def RocAucBinary(axis=-1, average='macro', sample_weight=None, max_fpr=None, multi_class='raise'):\n",
    "    \"Area Under the ROC Curve for single-label binary classification problems\"\n",
    "    skm_kwargs = dict(average=average, sample_weight=sample_weight, max_fpr=max_fpr, multi_class=multi_class)\n",
    "    return skm_to_fastai(skm.roc_auc_score, axis=axis,\n",
    "                         activation=ActivationType.BinarySoftmax, **skm_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def MatthewsCorrCoef(sample_weight=None, **kwargs):\n",
    "    \"Matthews correlation coefficient (`skm.matthews_corrcoef`) for single-label classification problems\"\n",
    "    # `kwargs` is forwarded as is to `skm_to_fastai`, together with `sample_weight`\n",
    "    kwargs['sample_weight'] = sample_weight\n",
    "    return skm_to_fastai(skm.matthews_corrcoef, **kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.matthews_corrcoef.html#sklearn.metrics.matthews_corrcoef) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class Perplexity(AvgLoss):\n",
    "    \"Perplexity (exponential of cross-entropy loss) for Language Models\"\n",
    "    @property\n",
    "    def value(self):\n",
    "        \"Exponential of `total/count`, or `None` when `count` is 0\"\n",
    "        if self.count == 0: return None\n",
    "        return torch.exp(self.total/self.count)\n",
    "    @property\n",
    "    def name(self):  return \"perplexity\"\n",
    "\n",
    "# Ready-to-use instance\n",
    "perplexity = Perplexity()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x1,x2 = torch.randn(20,5),torch.randint(0, 5, (20,))\n",
    "tst = perplexity\n",
    "tst.reset()\n",
    "vals = [0,6,15,20]\n",
    "learn = TstLearner()\n",
    "for i in range(3): \n",
    "    learn.yb = (x2[vals[i]:vals[i+1]],)\n",
    "    learn.loss = F.cross_entropy(x1[vals[i]:vals[i+1]],x2[vals[i]:vals[i+1]])\n",
    "    tst.accumulate(learn)\n",
    "test_close(tst.value, torch.exp(F.cross_entropy(x1,x2)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multi-label classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def accuracy_multi(inp, targ, thresh=0.5, sigmoid=True):\n",
    "    \"Compute accuracy when `inp` and `targ` are the same size.\"\n",
    "    flat_inp,flat_targ = flatten_check(inp,targ)\n",
    "    scores = flat_inp.sigmoid() if sigmoid else flat_inp\n",
    "    # A prediction counts as positive when its score is strictly above `thresh`\n",
    "    is_right = (scores>thresh) == flat_targ.bool()\n",
    "    return is_right.float().mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#For testing\n",
    "def change_1h_targ(targ, n):\n",
    "    idx = torch.randperm(targ.numel())[:n]\n",
    "    res = targ.clone().view(-1)\n",
    "    for i in idx: res[i] = 1-res[i]\n",
    "    return res.view(targ.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.randn(4,5)\n",
    "y = (torch.sigmoid(x) >= 0.5).byte()\n",
    "test_eq(accuracy_multi(x,y), 1)\n",
    "test_eq(accuracy_multi(x,1-y), 0)\n",
    "y1 = change_1h_targ(y, 5)\n",
    "test_eq(accuracy_multi(x,y1), 0.75)\n",
    "\n",
    "#Different thresh\n",
    "y = (torch.sigmoid(x) >= 0.2).byte()\n",
    "test_eq(accuracy_multi(x,y, thresh=0.2), 1)\n",
    "test_eq(accuracy_multi(x,1-y, thresh=0.2), 0)\n",
    "y1 = change_1h_targ(y, 5)\n",
    "test_eq(accuracy_multi(x,y1, thresh=0.2), 0.75)\n",
    "\n",
    "#No sigmoid\n",
    "y = (x >= 0.5).byte()\n",
    "test_eq(accuracy_multi(x,y, sigmoid=False), 1)\n",
    "test_eq(accuracy_multi(x,1-y, sigmoid=False), 0)\n",
    "y1 = change_1h_targ(y, 5)\n",
    "test_eq(accuracy_multi(x,y1, sigmoid=False), 0.75)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def APScoreMulti(sigmoid=True, average='macro', pos_label=1, sample_weight=None):\n",
    "    \"Average Precision score for multi-label classification problems\"\n",
    "    if sigmoid: activation = ActivationType.Sigmoid\n",
    "    else:       activation = ActivationType.No\n",
    "    skm_kwargs = dict(average=average, pos_label=pos_label, sample_weight=sample_weight)\n",
    "    return skm_to_fastai(skm.average_precision_score, activation=activation, flatten=False, **skm_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html#sklearn.metrics.average_precision_score) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def BrierScoreMulti(thresh=0.5, sigmoid=True, sample_weight=None, pos_label=None):\n",
    "    \"Brier score loss for multi-label classification problems\"\n",
    "    if sigmoid: activation = ActivationType.Sigmoid\n",
    "    else:       activation = ActivationType.No\n",
    "    skm_kwargs = dict(sample_weight=sample_weight, pos_label=pos_label)\n",
    "    return skm_to_fastai(skm.brier_score_loss, thresh=thresh, activation=activation, flatten=False,\n",
    "                         **skm_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.brier_score_loss.html#sklearn.metrics.brier_score_loss) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def F1ScoreMulti(thresh=0.5, sigmoid=True, labels=None, pos_label=1, average='macro', sample_weight=None):\n",
    "    \"F1 score (`skm.f1_score`) for multi-label classification problems\"\n",
    "    if sigmoid: activation = ActivationType.Sigmoid\n",
    "    else:       activation = ActivationType.No\n",
    "    skm_kwargs = dict(labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)\n",
    "    return skm_to_fastai(skm.f1_score, thresh=thresh, activation=activation, flatten=False,\n",
    "                         **skm_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html#sklearn.metrics.f1_score) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def FBetaMulti(beta, thresh=0.5, sigmoid=True, labels=None, pos_label=1, average='macro', sample_weight=None):\n",
    "    \"FBeta score with `beta` for multi-label classification problems\"\n",
    "    activation = ActivationType.Sigmoid if sigmoid else ActivationType.No\n",
    "    return skm_to_fastai(skm.fbeta_score, thresh=thresh, activation=activation, flatten=False,\n",
    "                beta=beta, labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.fbeta_score.html#sklearn.metrics.fbeta_score) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def HammingLossMulti(thresh=0.5, sigmoid=True, labels=None, sample_weight=None):\n",
    "    \"Hamming loss for multi-label classification problems\"\n",
    "    activation = ActivationType.Sigmoid if sigmoid else ActivationType.No\n",
    "    return skm_to_fastai(skm.hamming_loss, thresh=thresh, activation=activation, flatten=False,\n",
    "                         sample_weight=sample_weight)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.hamming_loss.html#sklearn.metrics.hamming_loss) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def JaccardMulti(thresh=0.5, sigmoid=True, labels=None, pos_label=1, average='macro', sample_weight=None):\n",
    "    \"Jaccard score for multi-label classification problems\"\n",
    "    activation = ActivationType.Sigmoid if sigmoid else ActivationType.No\n",
    "    return skm_to_fastai(skm.jaccard_score, thresh=thresh, activation=activation, flatten=False,\n",
    "                         labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.jaccard_score.html#sklearn.metrics.jaccard_score) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def MatthewsCorrCoefMulti(thresh=0.5, sigmoid=True, sample_weight=None):\n",
    "    \"Matthews correlation coefficient for multi-label classification problems\"\n",
    "    activation = ActivationType.Sigmoid if sigmoid else ActivationType.No\n",
    "    return skm_to_fastai(skm.matthews_corrcoef, thresh=thresh, activation=activation, flatten=False, sample_weight=sample_weight)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.matthews_corrcoef.html#sklearn.metrics.matthews_corrcoef) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def PrecisionMulti(thresh=0.5, sigmoid=True, labels=None, pos_label=1, average='macro', sample_weight=None):\n",
    "    \"Precision for multi-label classification problems\"\n",
    "    activation = ActivationType.Sigmoid if sigmoid else ActivationType.No\n",
    "    return skm_to_fastai(skm.precision_score, thresh=thresh, activation=activation, flatten=False,\n",
    "                         labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html#sklearn.metrics.precision_score) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def RecallMulti(thresh=0.5, sigmoid=True, labels=None, pos_label=1, average='macro', sample_weight=None):\n",
    "    \"Recall for multi-label classification problems\"\n",
    "    activation = ActivationType.Sigmoid if sigmoid else ActivationType.No\n",
    "    return skm_to_fastai(skm.recall_score, thresh=thresh, activation=activation, flatten=False,\n",
    "                         labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html#sklearn.metrics.recall_score) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def RocAucMulti(sigmoid=True, average='macro', sample_weight=None, max_fpr=None):\n",
    "    \"Area Under the Receiver Operating Characteristic Curve for multi-label binary classification problems\"\n",
    "    activation = ActivationType.Sigmoid if sigmoid else ActivationType.No\n",
    "    return skm_to_fastai(skm.roc_auc_score, activation=activation, flatten=False,\n",
    "                         average=average, sample_weight=sample_weight, max_fpr=max_fpr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "roc_auc_metric = RocAucMulti(sigmoid=False)\n",
    "x,y = torch.tensor([np.arange(start=0, stop=0.2, step=0.04)]*20), torch.tensor([0, 0, 1, 1]).repeat(5)\n",
    "assert compute_val(roc_auc_metric, x, y) == 0.5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score) for more details."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Regression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def mse(inp,targ):\n",
    "    \"Mean squared error between `inp` and `targ`.\"\n",
    "    return F.mse_loss(*flatten_check(inp,targ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x1,x2 = torch.randn(4,5),torch.randn(4,5)\n",
    "test_close(mse(x1,x2), (x1-x2).pow(2).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _rmse(inp, targ): return torch.sqrt(F.mse_loss(inp, targ))\n",
    "rmse = AccumMetric(_rmse)\n",
    "rmse.__doc__ = \"Root mean squared error\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"rmse\" class=\"doc_header\"><code>rmse</code><a href=\"\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>rmse</code>(**`preds`**, **`targs`**)\n",
       "\n",
       "Root mean squared error"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(rmse, name=\"rmse\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x1,x2 = torch.randn(20,5),torch.randn(20,5)\n",
    "test_eq(compute_val(rmse, x1, x2), torch.sqrt(F.mse_loss(x1,x2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def mae(inp,targ):\n",
    "    \"Mean absolute error between `inp` and `targ`.\"\n",
    "    inp,targ = flatten_check(inp,targ)\n",
    "    return torch.abs(inp - targ).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x1,x2 = torch.randn(4,5),torch.randn(4,5)\n",
    "test_eq(mae(x1,x2), torch.abs(x1-x2).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def msle(inp, targ):\n",
    "    \"Mean squared logarithmic error between `inp` and `targ`.\"\n",
    "    inp,targ = flatten_check(inp,targ)\n",
    "    return F.mse_loss(torch.log(1 + inp), torch.log(1 + targ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x1,x2 = torch.randn(4,5),torch.randn(4,5)\n",
    "x1,x2 = torch.relu(x1),torch.relu(x2)\n",
    "test_close(msle(x1,x2), (torch.log(x1+1)-torch.log(x2+1)).pow(2).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _exp_rmspe(inp,targ):\n",
    "    inp,targ = torch.exp(inp),torch.exp(targ)\n",
    "    return torch.sqrt(((targ - inp)/targ).pow(2).mean())\n",
    "exp_rmspe = AccumMetric(_exp_rmspe)\n",
    "exp_rmspe.__doc__ = \"Root mean square percentage error of the exponential of  predictions and targets\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"exp_rmspe\" class=\"doc_header\"><code>exp_rmspe</code><a href=\"\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>exp_rmspe</code>(**`preds`**, **`targs`**)\n",
       "\n",
       "Root mean square percentage error of the exponential of  predictions and targets"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(exp_rmspe, name=\"exp_rmspe\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x1,x2 = torch.randn(20,5),torch.randn(20,5)\n",
    "test_eq(compute_val(exp_rmspe, x1, x2), torch.sqrt((((torch.exp(x2) - torch.exp(x1))/torch.exp(x2))**2).mean()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def ExplainedVariance(sample_weight=None):\n",
    "    \"Explained variance between predictions and targets\"\n",
    "    return skm_to_fastai(skm.explained_variance_score, is_class=False, sample_weight=sample_weight)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.explained_variance_score.html#sklearn.metrics.explained_variance_score) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def R2Score(sample_weight=None):\n",
    "    \"R2 score between predictions and targets\"\n",
    "    return skm_to_fastai(skm.r2_score, is_class=False, sample_weight=sample_weight)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.r2_score.html#sklearn.metrics.r2_score) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@delegates(AccumMetric)\n",
    "def PearsonCorrCoef(dim_argmax=None, **kwargs):\n",
    "    \"Pearson correlation coefficient for regression problem\"\n",
    "    def pearsonr(x,y): return scs.pearsonr(x,y)[0]\n",
    "    return AccumMetric(pearsonr, invert_arg=False, dim_argmax=dim_argmax, **kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the [scipy documentation](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.pearsonr.html?highlight=pearson#scipy.stats.pearsonr) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.randint(-999, 999,(20,))\n",
    "y = torch.randint(-999, 999,(20,))\n",
    "test_eq(compute_val(PearsonCorrCoef(), x, y), scs.pearsonr(x.view(-1), y.view(-1))[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@delegates(AccumMetric)\n",
    "def SpearmanCorrCoef(dim_argmax=None, axis=0, nan_policy='propagate', **kwargs):\n",
    "    \"Spearman correlation coefficient for regression problem\"\n",
    "    def spearmanr(a,b=None,**kwargs): return scs.spearmanr(a,b,**kwargs)[0]\n",
    "    return AccumMetric(partial(spearmanr, axis=axis, nan_policy=nan_policy),\n",
    "                       invert_arg=False, dim_argmax=dim_argmax, **kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the [scipy documentation](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.spearmanr.html?highlight=spearman#scipy.stats.spearmanr) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.randint(-999, 999,(20,))\n",
    "y = torch.randint(-999, 999,(20,))\n",
    "test_eq(compute_val(SpearmanCorrCoef(), x, y), scs.spearmanr(x.view(-1), y.view(-1))[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Segmentation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def foreground_acc(inp, targ, bkg_idx=0, axis=1):\n",
    "    \"Computes non-background accuracy for multiclass segmentation\"\n",
    "    targ = targ.squeeze(1)\n",
    "    mask = targ != bkg_idx\n",
    "    return (inp.argmax(dim=axis)[mask]==targ[mask]).float().mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.randn(4,5,3,3)\n",
    "y = x.argmax(dim=1)[:,None]\n",
    "test_eq(foreground_acc(x,y), 1)\n",
    "y[0] = 0 #the 0s are ignored so we get the same value\n",
    "test_eq(foreground_acc(x,y), 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class Dice(Metric):\n",
    "    \"Dice coefficient metric for binary target in segmentation\"\n",
    "    def __init__(self, axis=1): self.axis = axis\n",
    "    def reset(self): self.inter,self.union = 0,0\n",
    "    def accumulate(self, learn):\n",
    "        pred,targ = flatten_check(learn.pred.argmax(dim=self.axis), learn.y)\n",
    "        self.inter += (pred*targ).float().sum().item()\n",
    "        self.union += (pred+targ).float().sum().item()\n",
    "\n",
    "    @property\n",
    "    def value(self): return 2. * self.inter/self.union if self.union > 0 else None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x1 = torch.randn(20,2,3,3)\n",
    "x2 = torch.randint(0, 2, (20, 3, 3))\n",
    "pred = x1.argmax(1)\n",
    "inter = (pred*x2).float().sum().item()\n",
    "union = (pred+x2).float().sum().item()\n",
    "test_eq(compute_val(Dice(), x1, x2), 2*inter/union)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class DiceMulti(Metric):\n",
    "    \"Averaged Dice metric (Macro F1) for multiclass target in segmentation\"\n",
    "    def __init__(self, axis=1): self.axis = axis\n",
    "    def reset(self): self.inter,self.union = {},{}\n",
    "    def accumulate(self, learn):\n",
    "        pred,targ = flatten_check(learn.pred.argmax(dim=self.axis), learn.y)\n",
    "        for c in range(learn.pred.shape[self.axis]):\n",
    "            p = torch.where(pred == c, 1, 0)\n",
    "            t = torch.where(targ == c, 1, 0)\n",
    "            c_inter = (p*t).float().sum().item()\n",
    "            c_union = (p+t).float().sum().item()\n",
    "            if c in self.inter:\n",
    "                self.inter[c] += c_inter\n",
    "                self.union[c] += c_union\n",
    "            else:\n",
    "                self.inter[c] = c_inter\n",
    "                self.union[c] = c_union\n",
    "\n",
    "    @property\n",
    "    def value(self):\n",
    "        binary_dice_scores = np.array([])\n",
    "        for c in self.inter:\n",
    "            binary_dice_scores = np.append(binary_dice_scores, 2.*self.inter[c]/self.union[c] if self.union[c] > 0 else np.nan)\n",
    "        return np.nanmean(binary_dice_scores)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The DiceMulti method implements the \"Averaged F1: arithmetic mean over harmonic means\" described in this publication: https://arxiv.org/pdf/1911.03347.pdf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x1a = torch.ones(20,1,1,1)\n",
    "x1b = torch.clone(x1a)*0.5\n",
    "x1c = torch.clone(x1a)*0.3\n",
    "x1 = torch.cat((x1a,x1b,x1c),dim=1)   # Prediction: 20xClass0\n",
    "x2 = torch.zeros(20,1,1)              # Target: 20xClass0\n",
    "test_eq(compute_val(DiceMulti(), x1, x2), 1.)\n",
    "\n",
    "x2 = torch.ones(20,1,1)               # Target: 20xClass1\n",
    "test_eq(compute_val(DiceMulti(), x1, x2), 0.)\n",
    "\n",
    "x2a = torch.zeros(10,1,1)\n",
    "x2b = torch.ones(5,1,1)\n",
    "x2c = torch.ones(5,1,1) * 2\n",
    "x2 = torch.cat((x2a,x2b,x2c),dim=0)   # Target: 10xClass0, 5xClass1, 5xClass2\n",
    "dice1 = (2*10)/(2*10+10)              # Dice: 2*TP/(2*TP+FP+FN)\n",
    "dice2 = 0\n",
    "dice3 = 0\n",
    "test_eq(compute_val(DiceMulti(), x1, x2), (dice1+dice2+dice3)/3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class JaccardCoeff(Dice):\n",
    "    \"Implementation of the Jaccard coefficient that is lighter in RAM\"\n",
    "    @property\n",
    "    def value(self): return self.inter/(self.union-self.inter) if self.union > 0 else None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x1 = torch.randn(20,2,3,3)\n",
    "x2 = torch.randint(0, 2, (20, 3, 3))\n",
    "pred = x1.argmax(1)\n",
    "inter = (pred*x2).float().sum().item()\n",
    "union = (pred+x2).float().sum().item()\n",
    "test_eq(compute_val(JaccardCoeff(), x1, x2), inter/(union-inter))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## NLP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class CorpusBLEUMetric(Metric):\n",
    "    def __init__(self, vocab_sz=5000, axis=-1):\n",
    "        \"BLEU Metric calculated over the validation corpus\"\n",
    "        self.metric_name = 'CorpusBLEU'\n",
    "        self.axis, self.vocab_sz = axis, vocab_sz\n",
    "        self.pred_len,self.targ_len,self.samp_idx,self.corrects,self.counts, = 0,0,0,[0]*4,[0]*4\n",
    "\n",
    "    def reset(self):\n",
    "        self.pred_len,self.targ_len,self.corrects,self.counts = 0,0,[0]*4,[0]*4\n",
    "\n",
    "    class NGram():\n",
    "        def __init__(self, ngram, max_n=5000): self.ngram,self.max_n = ngram,max_n\n",
    "        def __eq__(self, other):\n",
    "            if len(self.ngram) != len(other.ngram): return False\n",
    "            return np.all(np.array(self.ngram) == np.array(other.ngram))\n",
    "        def __hash__(self): return int(sum([o * self.max_n**i for i,o in enumerate(self.ngram)]))\n",
    "\n",
    "    def get_grams(self, x, n, max_n=5000):\n",
    "        return x if n==1 else [self.NGram(x[i:i+n], max_n=max_n) for i in range(len(x)-n+1)]\n",
    "\n",
    "    def get_correct_ngrams(self, pred, targ, n, max_n=5000):\n",
    "        pred_grams,targ_grams = self.get_grams(pred, n, max_n=max_n),self.get_grams(targ, n, max_n=max_n)\n",
    "        pred_cnt,targ_cnt = Counter(pred_grams),Counter(targ_grams)\n",
    "        return sum([min(c, targ_cnt[g]) for g,c in pred_cnt.items()]),len(pred_grams)\n",
    "\n",
    "    def accumulate(self, learn):\n",
    "        if learn.training: return None\n",
    "        else:\n",
    "            last_output = learn.pred.argmax(dim=self.axis)\n",
    "            last_target = learn.y\n",
    "            for pred,targ in zip(last_output.cpu().numpy(),last_target.cpu().numpy()):\n",
    "                self.pred_len += len(pred)\n",
    "                self.targ_len += len(targ)\n",
    "                smooth_mteval = 1\n",
    "                for i in range(4):\n",
    "                    c,t = self.get_correct_ngrams(pred, targ, i+1, max_n=self.vocab_sz)\n",
    "                    if c == 0:\n",
    "                        smooth_mteval *= 2\n",
    "                        c = 1 / smooth_mteval    # exp smoothing, method 3 from http://acl2014.org/acl2014/W14-33/pdf/W14-3346.pdf\n",
    "                    self.corrects[i] += c\n",
    "                    self.counts[i]   += t\n",
    "\n",
    "    @property\n",
    "    def value(self):\n",
    "        if self.counts == 0: return None\n",
    "        elif max(self.corrects) == 0: return 0.0\n",
    "        else:\n",
    "            precs = [c/t for c,t in zip(self.corrects,self.counts)]\n",
    "            len_penalty = math.exp(1 - self.targ_len/self.pred_len) if self.pred_len < self.targ_len else 1\n",
    "            return len_penalty * ((precs[0]*precs[1]*precs[2]*precs[3]) ** 0.25)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_vcb_emb(pred, targ):\n",
    "    # create vocab \"embedding\" for predictions\n",
    "    vcb_sz = max(torch.unique(torch.cat([pred, targ])))+1\n",
    "    pred_emb=torch.zeros(pred.size()[0], pred.size()[1] ,vcb_sz)\n",
    "    for i,v in enumerate(pred):\n",
    "        pred_emb[i].scatter_(1, v.view(len(v),1),1)\n",
    "    return pred_emb\n",
    "\n",
    "def compute_bleu_val(met, x1, x2):\n",
    "    met.reset()\n",
    "    learn = TstLearner()\n",
    "    learn.training=False    \n",
    "    for i in range(len(x1)): \n",
    "        learn.pred,learn.yb = x1, (x2,)\n",
    "        met.accumulate(learn)\n",
    "    return met.value\n",
    "\n",
    "targ = torch.tensor([[1,2,3,4,5,6,1,7,8]]) \n",
    "pred = torch.tensor([[1,9,3,4,5,6,1,10,8]])\n",
    "pred_emb = create_vcb_emb(pred, targ)\n",
    "test_close(compute_bleu_val(CorpusBLEUMetric(), pred_emb, targ), 0.48549)\n",
    "\n",
    "targ = torch.tensor([[1,2,3,4,5,6,1,7,8],[1,2,3,4,5,6,1,7,8]]) \n",
    "pred = torch.tensor([[1,9,3,4,5,6,1,10,8],[1,9,3,4,5,6,1,10,8]])\n",
    "pred_emb = create_vcb_emb(pred, targ)\n",
    "test_close(compute_bleu_val(CorpusBLEUMetric(), pred_emb, targ), 0.48549)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The BLEU metric was introduced in [this article](https://www.aclweb.org/anthology/P02-1040) to come up with a way to evaluate the performance of translation models. It's based on the precision of n-grams in your prediction compared to your target. See the [fastai NLP course BLEU notebook](https://github.com/fastai/course-nlp/blob/master/bleu_metric.ipynb) for a more detailed description of BLEU.\n",
    "\n",
    "The smoothing used in the precision calculation is the same as in [SacreBLEU](https://github.com/mjpost/sacrebleu/blob/32c54cdd0dfd6a9fadd5805f2ea189ac0df63907/sacrebleu/sacrebleu.py#L540-L542), which in turn is \"method 3\" from the [Chen & Cherry, 2014](http://acl2014.org/acl2014/W14-33/pdf/W14-3346.pdf) paper."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LossMetrics -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class LossMetric(AvgMetric):\n",
    "    \"Create a metric from `loss_func.attr` named `nm`\"\n",
    "    def __init__(self, attr, nm=None): store_attr('attr,nm')\n",
    "    def accumulate(self, learn):\n",
    "        bs = find_bs(learn.yb)\n",
    "        self.total += learn.to_detach(getattr(learn.loss_func, self.attr, 0))*bs\n",
    "        self.count += bs\n",
    "\n",
    "    @property\n",
    "    def name(self): return self.attr if self.nm is None else self.nm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def LossMetrics(attrs, nms=None):\n",
    "    \"List of `LossMetric` for each of `attrs` and `nms`\"\n",
    "    if isinstance(attrs, str): attrs = attrs.split(',')\n",
    "    nms = attrs if nms is None else nms.split(',') if isinstance(nms, str) else nms\n",
    "    return [LossMetric(a, n) for a,n in zip(attrs,nms)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from fastai.test_utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CombineL1L2(Module):\n",
    "    def forward(self, out, targ):\n",
    "        self.l1 = F.l1_loss(out, targ)\n",
    "        self.l2 = F.mse_loss(out, targ)\n",
    "        return self.l1+self.l2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, 16.63826560974121, 14.52301025390625, 3.3376736640930176, 11.18533706665039, '00:00']\n",
      "[1, 14.520439147949219, 10.179483413696289, 2.7222838401794434, 7.457200050354004, '00:00']\n"
     ]
    }
   ],
   "source": [
    "learn = synth_learner(metrics=LossMetrics('l1,l2'))\n",
    "learn.loss_func = CombineL1L2()\n",
    "learn.fit(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_torch_core.ipynb.\n",
      "Converted 01_layers.ipynb.\n",
      "Converted 01a_losses.ipynb.\n",
      "Converted 02_data.load.ipynb.\n",
      "Converted 03_data.core.ipynb.\n",
      "Converted 04_data.external.ipynb.\n",
      "Converted 05_data.transforms.ipynb.\n",
      "Converted 06_data.block.ipynb.\n",
      "Converted 07_vision.core.ipynb.\n",
      "Converted 08_vision.data.ipynb.\n",
      "Converted 09_vision.augment.ipynb.\n",
      "Converted 09b_vision.utils.ipynb.\n",
      "Converted 09c_vision.widgets.ipynb.\n",
      "Converted 10_tutorial.pets.ipynb.\n",
      "Converted 10b_tutorial.albumentations.ipynb.\n",
      "Converted 11_vision.models.xresnet.ipynb.\n",
      "Converted 12_optimizer.ipynb.\n",
      "Converted 13_callback.core.ipynb.\n",
      "Converted 13a_learner.ipynb.\n",
      "Converted 13b_metrics.ipynb.\n",
      "Converted 14_callback.schedule.ipynb.\n",
      "Converted 14a_callback.data.ipynb.\n",
      "Converted 15_callback.hook.ipynb.\n",
      "Converted 15a_vision.models.unet.ipynb.\n",
      "Converted 16_callback.progress.ipynb.\n",
      "Converted 17_callback.tracker.ipynb.\n",
      "Converted 18_callback.fp16.ipynb.\n",
      "Converted 18a_callback.training.ipynb.\n",
      "Converted 18b_callback.preds.ipynb.\n",
      "Converted 19_callback.mixup.ipynb.\n",
      "Converted 20_interpret.ipynb.\n",
      "Converted 20a_distributed.ipynb.\n",
      "Converted 21_vision.learner.ipynb.\n",
      "Converted 22_tutorial.imagenette.ipynb.\n",
      "Converted 23_tutorial.vision.ipynb.\n",
      "Converted 24_tutorial.siamese.ipynb.\n",
      "Converted 24_vision.gan.ipynb.\n",
      "Converted 30_text.core.ipynb.\n",
      "Converted 31_text.data.ipynb.\n",
      "Converted 32_text.models.awdlstm.ipynb.\n",
      "Converted 33_text.models.core.ipynb.\n",
      "Converted 34_callback.rnn.ipynb.\n",
      "Converted 35_tutorial.wikitext.ipynb.\n",
      "Converted 36_text.models.qrnn.ipynb.\n",
      "Converted 37_text.learner.ipynb.\n",
      "Converted 38_tutorial.text.ipynb.\n",
      "Converted 39_tutorial.transformers.ipynb.\n",
      "Converted 40_tabular.core.ipynb.\n",
      "Converted 41_tabular.data.ipynb.\n",
      "Converted 42_tabular.model.ipynb.\n",
      "Converted 43_tabular.learner.ipynb.\n",
      "Converted 44_tutorial.tabular.ipynb.\n",
      "Converted 45_collab.ipynb.\n",
      "Converted 46_tutorial.collab.ipynb.\n",
      "Converted 50_tutorial.datablock.ipynb.\n",
      "Converted 60_medical.imaging.ipynb.\n",
      "Converted 61_tutorial.medical_imaging.ipynb.\n",
      "Converted 65_medical.text.ipynb.\n",
      "Converted 70_callback.wandb.ipynb.\n",
      "Converted 71_callback.tensorboard.ipynb.\n",
      "Converted 72_callback.neptune.ipynb.\n",
      "Converted 73_callback.captum.ipynb.\n",
      "Converted 74_callback.cutmix.ipynb.\n",
      "Converted 97_test_utils.ipynb.\n",
      "Converted 99_pytorch_doc.ipynb.\n",
      "Converted dev-setup.ipynb.\n",
      "Converted index.ipynb.\n",
      "Converted quick_start.ipynb.\n",
      "Converted tutorial.ipynb.\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "from nbdev.export import notebook2script\n",
    "notebook2script()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
